# -*- coding: utf-8 -*-
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from functools import wraps, partial
from typing import Union, Sequence, Dict, Callable, Tuple, Type, Optional, Any

import jax
import numpy as np
import numpy as onp
from jax import numpy as jnp
from jax.lax import cond

# Lazily-populated references to brainpy objects. They are ``None`` at import
# time and are filled in on first use via ``global`` assignments:
# ``conn`` by ``is_connector``, ``init`` by ``is_initializer``, ``var_obs`` by
# ``is_all_vars``, ``Array`` by ``is_initializer``/``is_connector`` and
# ``BrainPyObject`` by ``is_all_objs``.
# NOTE(review): presumably deferred to avoid a circular import with ``brainpy``
# at module load time — confirm.
conn = None
init = None
var_obs = None

Array = None
BrainPyObject = None

# Public API of this module.
__all__ = [
    'is_checking',
    'turn_on',
    'turn_off',

    'is_shape_consistency',
    'is_shape_broadcastable',
    'check_shape_except_batch',
    'check_shape',
    'is_dict_data',
    'is_callable',
    'is_initializer',
    'is_connector',
    'is_float',
    'is_integer',
    'is_string',
    'is_sequence',
    'is_subclass',
    'is_instance',
    'is_elem_or_seq_or_dict',
    'is_all_vars',
    'is_all_objs',
    'jit_error',
    'jit_error_checking',
    'jit_error_checking_no_args',

    'serialize_kwargs',
]

# Global checking switch: returned by ``is_checking`` and set by
# ``turn_on``/``turn_off``.
_check = True
# NOTE(review): ``_name_check`` is not read or written by any function in this
# module — confirm whether it is still used elsewhere.
_name_check = True


def is_checking():
    """Return the current value of the module-level checking switch ``_check``."""
    return _check


def turn_on():
    """Enable checking by setting the module-level switch ``_check`` to ``True``."""
    global _check
    _check = True


def turn_off():
    """Disable checking by setting the module-level switch ``_check`` to ``False``."""
    global _check
    _check = False


# TODO: no ``turn_off_name_check`` toggle for ``_name_check`` is implemented yet.

def is_shape_consistency(shapes, free_axes=None, return_format_shapes=False):
    """Check that all shapes are identical outside the given free axes.

    Parameters::

    shapes: tuple, list
      A sequence of shapes. Each shape must itself be a tuple or list of ints,
      and all shapes must have the same number of dimensions.
    free_axes: None, int, tuple, list
      Axes that may differ between shapes (negative values count from the end).
      ``None`` means every axis must match.
    return_format_shapes: bool
      Whether to return the common shape and the sizes along the free axes.

    Returns::

    ``None`` when ``return_format_shapes`` is False. Otherwise a pair
    ``(common_shape, free_shapes)``. ``common_shape`` is the shared shape with
    the free axes removed. ``free_shapes`` is one of:

    - a tuple with each shape's size at the free axis (int ``free_axes``),
    - a tuple of per-shape tuples of sizes (sequence ``free_axes``),
    - ``None`` (``free_axes=None``).

    Raises::

    AssertionError
      If ``shapes`` or one of its elements is not a tuple/list.
    ValueError
      If the shapes differ in rank or outside the free axes, or ``free_axes``
      has an unsupported type.
    """
    assert isinstance(shapes, (tuple, list)), f'Must be a sequence of shape. While we got {shapes}.'
    for shape in shapes:
        # Every element must itself be a shape sequence.
        assert isinstance(shape, (tuple, list)), (f'Must be a sequence of shape. While '
                                                  f'we got one element is {shape}.')
    dims = onp.unique([len(shape) for shape in shapes])
    if len(dims) > 1:
        raise ValueError('The provided shape dimensions are not consistent. ')
    if free_axes is None:
        type_ = 'none'
        free_axes = ()
    elif isinstance(free_axes, (tuple, list)):
        type_ = 'seq'
        free_axes = tuple(free_axes)
    elif isinstance(free_axes, int):
        type_ = 'int'
        free_axes = (free_axes,)
    else:
        raise ValueError(f'"free_axes" must be None, an int, or a tuple/list of int, '
                         f'but we got {free_axes}.')
    # Normalize negative axes against the common rank (as plain Python ints).
    free_axes = [(int(dims[0]) + axis if axis < 0 else axis) for axis in free_axes]
    # Drop the free axes; what remains must be identical across all shapes.
    all_shapes = [tuple(sh for i, sh in enumerate(shape) if i not in free_axes)
                  for shape in shapes]
    unique_shape = tuple(set(all_shapes))
    if len(unique_shape) > 1:
        if len(free_axes):
            raise ValueError(f'The provided shape (without axes of {free_axes}) are not consistent.')
        else:
            raise ValueError(f'The provided shape are not consistent.')
    if return_format_shapes:
        if type_ == 'int':
            free_shapes = tuple(shape[free_axes[0]] for shape in shapes)
        elif type_ == 'seq':
            free_shapes = tuple(tuple(shape[axis] for axis in free_axes) for shape in shapes)
        else:
            free_shapes = None
        return unique_shape[0], free_shapes


def is_shape_broadcastable(shapes, free_axes=(), return_format_shapes=False):
    """Check shapes for consistency after left-padding them to a common rank.

    See https://numpy.org/doc/stable/reference/generated/numpy.broadcast.html
    for more details.

    Each shape is left-padded with ``1`` up to the largest rank among
    ``shapes``. The padded shapes are then handed to ``is_shape_consistency``,
    which requires them to be equal outside ``free_axes``.

    NOTE(review): size-1 axes are not expanded, so e.g. ``(1, 3)`` vs ``(2, 3)``
    is rejected even though NumPy would broadcast them. Confirm this strictness
    is intended.

    Parameters::

    shapes: sequence of tuple/list of int
      The shapes to compare. Must be non-empty.
    free_axes: None, int, tuple, list
      Axes (in the padded rank) that may differ. The default ``()`` checks
      every axis.
    return_format_shapes: bool
      Forwarded to ``is_shape_consistency``.

    Returns::

    Whatever ``is_shape_consistency`` returns for the padded shapes.
    """
    target_ndim = max(len(s) for s in shapes)
    padded = []
    for s in shapes:
        padded.append([1] * (target_ndim - len(s)) + list(s))
    return is_shape_consistency(padded, free_axes, return_format_shapes)


def check_shape_except_batch(shape1, shape2, batch_idx=0, mode='raise'):
    """Check whether two shapes agree on every axis except ``batch_idx``.

    Parameters::

    shape1, shape2: sequence of int
      The shapes to compare.
    batch_idx: int
      Index of the batch axis, removed from both shapes before comparing.
    mode: str
      ``'raise'`` raises ``ValueError`` on mismatch; ``'bool'`` returns ``False``.

    Returns::

    ``True`` when compatible; ``False`` on mismatch in ``'bool'`` mode.
    """
    assert mode in ['raise', 'bool']

    def _reject(message):
        # Either raise or report failure, depending on ``mode``.
        if mode == 'raise':
            raise ValueError(message)
        return False

    if len(shape1) != len(shape2):
        return _reject(f'Dimension mismatch between two shapes. '
                       f'{shape1} != {shape2}')
    rest1 = list(shape1)
    rest2 = list(shape2)
    rest1.pop(batch_idx)
    rest2.pop(batch_idx)
    if rest1 != rest2:
        return _reject(f'Two shapes {rest1} and {rest2} are not '
                       f'consistent when excluding the batch axis '
                       f'{batch_idx}')
    return True


def check_shape(all_shapes, free_axes: Union[Sequence[int], int] = -1):
    """Check that shapes are broadcast-compatible outside the free axes.

    Shapes are first left-padded with ``1`` to a common rank. On every
    non-free axis, each size must be ``1`` or equal to the maximum size
    along that axis.

    Parameters::

    all_shapes: dict, tuple, list
      The shapes to check (for a dict, its values are used).
    free_axes: int, sequence of int
      Axes (in the padded rank, negative values count from the end) excluded
      from the compatibility check.

    Returns::

    ``(free_shape, max_fixed_shapes)``. ``free_shape`` holds each shape's size
    at the free axis (int ``free_axes``) or a list of sizes at the free axes
    (sequence ``free_axes``). ``max_fixed_shapes`` holds the maximum size for
    each non-free axis.

    Raises::

    ValueError
      If ``all_shapes`` has an unsupported type or is empty, or the shapes are
      not broadcast compatible outside ``free_axes``.
    AssertionError
      If ``free_axes`` is neither an int nor a tuple/list.
    """
    # check "all_shapes"
    if isinstance(all_shapes, dict):
        all_shapes = tuple(all_shapes.values())
    elif isinstance(all_shapes, (tuple, list)):
        all_shapes = tuple(all_shapes)
    else:
        raise ValueError(f'"all_shapes" must be a dict, tuple or list of shapes, '
                         f'but we got {type(all_shapes)}.')
    if not all_shapes:
        raise ValueError('"all_shapes" must contain at least one shape.')
    # maximum number of dimension; left-pad every shape with 1s up to it
    max_dim = max(len(shape) for shape in all_shapes)
    all_shapes = [[1] * (max_dim - len(s)) + list(s) for s in all_shapes]
    # check "free_axes"
    type_ = 'seq'
    if isinstance(free_axes, int):
        free_axes = (free_axes,)
        type_ = 'int'
    elif isinstance(free_axes, (tuple, list)):
        free_axes = tuple(free_axes)
    assert isinstance(free_axes, tuple), (f'"free_axes" must be an int or a sequence of int, '
                                          f'but we got {free_axes}.')
    free_axes = [(axis + max_dim if axis < 0 else axis) for axis in free_axes]
    fixed_axes = [i for i in range(max_dim) if i not in free_axes]
    # get all free shapes
    if type_ == 'int':
        free_shape = [shape[free_axes[0]] for shape in all_shapes]
    else:
        free_shape = [[shape[axis] for axis in free_axes] for shape in all_shapes]
    # get all assumed fixed shapes, grouped per axis
    fixed_shapes = [[shape[axis] for shape in all_shapes] for axis in fixed_axes]
    max_fixed_shapes = [max(shape) for shape in fixed_shapes]
    # broadcast compatible: every size on an axis is 1 or that axis' maximum
    for i, shape in enumerate(fixed_shapes):
        if len(set(shape) - {1, max_fixed_shapes[i]}):
            raise ValueError(f'Shapes out of axes {free_axes} are not '
                             f'broadcast compatible: \n'
                             f'{all_shapes}')
    return free_shape, max_fixed_shapes


def is_dict_data(a_dict: Dict,
                 key_type: Union[Type, Tuple[Type, ...]] = None,
                 val_type: Union[Type, Tuple[Type, ...]] = None,
                 name: str = None,
                 allow_none: bool = True):
    """Validate that ``a_dict`` is a dict whose keys/values have the given types.

    Parameters::

    a_dict: dict
      The dictionary to validate.
    key_type: optional, type or tuple of types
      Required type of every key (not checked when ``None``).
    val_type: optional, type or tuple of types
      Required type of every value (not checked when ``None``).
    name: optional, str
      Name used (quoted) in error messages.
    allow_none: bool
      Whether ``None`` is accepted (and returned as ``None``).

    Returns::

    The input dict itself (or ``None``).
    """
    if a_dict is None and allow_none:
        return None
    label = f'"{name}"' if name is not None else ''
    if not isinstance(a_dict, dict):
        raise ValueError(f'{label} must be a dict, while we got {type(a_dict)}')
    for k, v in a_dict.items():
        bad_key = key_type is not None and not isinstance(k, key_type)
        bad_val = val_type is not None and not isinstance(v, val_type)
        if bad_key or bad_val:
            raise ValueError(f'{label} must be a dict of ({key_type}, {val_type}), '
                             f'while we got ({type(k)}, {type(v)})')
    return a_dict


def is_callable(fun: Callable,
                name: str = None,
                allow_none: bool = False):
    """Validate that ``fun`` is callable.

    Parameters::

    fun: callable
      The object to validate.
    name: optional, str
      Name used in error messages.
    allow_none: bool
      Whether ``None`` is accepted (and returned as ``None``).

    Returns::

    ``fun`` itself (or ``None``).
    """
    label = name if name is not None else ''
    if fun is not None:
        if callable(fun):
            return fun
        raise ValueError(f'{label} should be a callable function. While we got {type(fun)}')
    if allow_none:
        return None
    raise ValueError(f'{label} must be a callable function, but we got None.')


def is_initializer(
    initializer,
    name: str = None,
    allow_none: bool = False
):
    """Validate that ``initializer`` can serve as an initializer.

    Accepted values are instances of ``brainpy.initialize.Initializer``, brainpy
    ``Array`` or ``jax.Array`` objects, and any other callable.

    Parameters::

    initializer: Any
      The object to validate.
    name: optional, str
      Name used in error messages.
    allow_none: bool
      Whether ``None`` is accepted (and returned as ``None``).

    Returns::

    ``initializer`` itself (or ``None``).
    """
    # Resolve the lazily-imported module-level references on first use.
    global Array, init
    if Array is None:
        from brainpy.math.ndarray import Array
    if init is None:
        from brainpy import initialize as init

    label = '' if name is None else name
    if initializer is None:
        if allow_none:
            return
        raise ValueError(f'{label} must be an initializer, but we got None.')
    if isinstance(initializer, (init.Initializer, Array, jax.Array)) or callable(initializer):
        return initializer
    raise ValueError(f'{label} should be an instance of brainpy.init.Initializer, '
                     f'tensor or callable function. While we got {type(initializer)}')


def is_connector(
    connector,
    name: str = None,
    allow_none: bool = False
):
    """Validate that ``connector`` can serve as a connector.

    Accepted values are instances of ``brainpy.connect.Connector``, brainpy
    ``Array`` or ``jax.Array`` objects, and any other callable.

    Parameters::

    connector: Any
      The object to validate.
    name: optional, str
      Name used in error messages.
    allow_none: bool
      Whether ``None`` is accepted (and returned as ``None``).

    Returns::

    ``connector`` itself (or ``None``).

    Raises::

    ValueError
      If ``connector`` is ``None`` while ``allow_none`` is False, or is of an
      unsupported type.
    """
    # Resolve the lazily-imported module-level references on first use.
    global Array
    if Array is None:
        from brainpy.math.ndarray import Array as Array
    global conn
    if conn is None: from brainpy import connect as conn

    name = '' if name is None else name
    if connector is None:
        if allow_none:
            return None
        else:
            raise ValueError(f'{name} must be a connector, but we got None.')
    if isinstance(connector, conn.Connector):
        return connector
    elif isinstance(connector, (Array, jax.Array)):
        return connector
    elif callable(connector):
        return connector
    else:
        raise ValueError(f'{name} should be an instance of brainpy.conn.Connector, '
                         f'tensor or callable function. While we got {type(connector)}')


def is_sequence(
    value: Sequence,
    name: str = None,
    elem_type: Union[type, Sequence[type]] = None,
    allow_none: bool = True
):
    """Validate that ``value`` is a tuple/list, optionally of a given element type.

    Parameters::

    value: tuple, list
      The sequence to validate.
    name: optional, str
      Name used in error messages.
    elem_type: optional, type or tuple of types
      Required type of every element (not checked when ``None``).
    allow_none: bool
      Whether ``None`` is accepted (and returned as ``None``).

    Returns::

    ``value`` itself (or ``None``).

    Raises::

    ValueError
      If ``value`` is ``None`` while ``allow_none`` is False, is not a
      tuple/list, or contains an element of the wrong type.
    """
    if name is None: name = ''
    if value is None:
        if allow_none:
            return
        else:
            raise ValueError(f'{name} must be a sequence, but got None')
    if not isinstance(value, (tuple, list)):
        raise ValueError(f'{name} should be a sequence, but we got a {type(value)}')
    if elem_type is not None:
        for v in value:
            if not isinstance(v, elem_type):
                # Report the offending element's type, not the expected type's type.
                raise ValueError(f'Elements in {name} should be {elem_type}, '
                                 f'but we got {type(v)}: {v}')
    return value


def is_float(
    value: float,
    name: str = None,
    min_bound: float = None,
    max_bound: float = None,
    allow_none: bool = False,
    allow_int: bool = True
) -> float:
    """Validate that ``value`` is a float (optionally an int) within bounds.

    Parameters::

    value: Any
      The value to validate.
    name: optional, str
      Name used in error messages.
    min_bound: optional, float
      The allowed minimum value, enforced through ``jit_error_checking_no_args``.
    max_bound: optional, float
      The allowed maximum value, enforced through ``jit_error_checking_no_args``.
    allow_none: bool
      Whether ``None`` is accepted (and returned as ``None``).
    allow_int: bool
      Whether Python/NumPy integers are accepted in addition to floats.

    Returns::

    ``value`` itself (or ``None``).
    """
    label = '' if name is None else name
    if value is None:
        if not allow_none:
            raise ValueError(f'{label} must be a float, but got None')
        return None
    accepted_types = (float, int, np.integer, np.floating) if allow_int else (float, np.floating)
    if not isinstance(value, accepted_types):
        raise ValueError(f'{label} must be a float, but got {type(value)}')
    if min_bound is not None:
        too_small = value < min_bound
        jit_error_checking_no_args(too_small,
                                   ValueError(f"{label} must be a float bigger than {min_bound}, "
                                              f"while we got {value}"))
    if max_bound is not None:
        too_large = value > max_bound
        jit_error_checking_no_args(too_large,
                                   ValueError(f"{label} must be a float smaller than {max_bound}, "
                                              f"while we got {value}"))
    return value


def is_integer(value: int, name=None, min_bound=None, max_bound=None, allow_none=False):
    """Validate that ``value`` is an integer (or a 0-d integer array) within bounds.

    Parameters::

    value: int, optional
      The value to validate. Python ints, NumPy integers and 0-d array-likes
      with an integer dtype are accepted.
    name: optional, str
      Name used in error messages.
    min_bound: optional, int
      The allowed minimum value, enforced through ``jit_error_checking_no_args``.
    max_bound: optional, int
      The allowed maximum value, enforced through ``jit_error_checking_no_args``.
    allow_none: bool
      Whether ``None`` is accepted (and returned as ``None``).

    Returns::

    ``value`` itself (or ``None``).
    """
    label = '' if name is None else name
    if value is None:
        if allow_none:
            return
        raise ValueError(f'{label} must be an int, but got None')
    if not isinstance(value, (int, np.integer)):
        # Otherwise only a scalar (0-d) array-like with an integer dtype is accepted.
        is_int_scalar_array = (hasattr(value, '__array__')
                               and np.issubdtype(value.dtype, np.integer)
                               and value.ndim == 0
                               and value.size == 1)
        if not is_int_scalar_array:
            raise ValueError(f'{label} must be an int, but got {value}')
    if min_bound is not None:
        below = jnp.any(value < min_bound)
        jit_error_checking_no_args(below,
                                   ValueError(f"{label} must be an int bigger than {min_bound}, "
                                              f"while we got {value}"))
    if max_bound is not None:
        above = jnp.any(value > max_bound)
        jit_error_checking_no_args(above,
                                   ValueError(f"{label} must be an int smaller than {max_bound}, "
                                              f"while we got {value}"))
    return value


def is_string(value: str, name: str = None, candidates: Sequence[str] = None, allow_none=False):
    """Check that ``value`` is a string, optionally one of ``candidates``.

    Parameters::

    value: str
      The value to validate.
    name: optional, str
      Name used in error messages.
    candidates: optional, sequence of str
      If given, ``value`` must be one of these.
    allow_none: bool
      Whether ``None`` is accepted (and returned as ``None``).

    Returns::

    ``value`` itself (or ``None``).

    Raises::

    ValueError
      If ``value`` is ``None`` while ``allow_none`` is False, is not a ``str``,
      or is not among ``candidates``.
    """
    if name is None: name = ''
    if value is None:
        if allow_none:
            return None
        else:
            raise ValueError(f'{name} must be a str, but got None')
    # Enforce the type, not only membership in ``candidates``.
    if not isinstance(value, str):
        raise ValueError(f'{name} must be a str, but got {type(value)}')
    if candidates is not None:
        if value not in candidates:
            raise ValueError(f'{name} must be a str in {candidates}, '
                             f'but we got {value}')
    return value


def serialize_kwargs(shared_kwargs: Optional[Dict]):
    """Render keyword arguments as a key-sorted ``str(dict)`` representation.

    ``None`` is treated as an empty dict. Keys must be ``str`` and values one of
    ``bool``/``float``/``int``/``complex``/``str`` (validated by ``is_dict_data``).
    """
    kwargs = {} if shared_kwargs is None else shared_kwargs
    is_dict_data(kwargs,
                 key_type=str,
                 val_type=(bool, float, int, complex, str),
                 name='shared_kwargs')
    # Sorting by key makes the output independent of insertion order.
    ordered = dict(sorted(kwargs.items(), key=lambda item: item[0]))
    return str(ordered)


def is_subclass(
    instance: Any,
    supported_types: Union[Type, Sequence[Type]],
    name: str = ''
) -> Any:
    r"""Check whether the instance is in the inheritance tree of the supported types.

    This function checks whether the exact type of ``instance`` is one of the
    supported types or an ancestor (parent type) of one of them.

    Here we have the following inheritance hierarchy::

             A
           /   \
          B     C
         / \   / \
        D   E F   G

    If ``supported_types`` is ``[E, F]``, then

    - instances of ``D`` or ``G`` fail the check.
    - instances of ``E`` or ``F`` pass the check.
    - instances of ``B`` or ``C`` also pass the check.
    - instances of ``A`` pass the check too.

    Because the comparison uses ``type(instance)``, an instance of a subclass
    of ``E`` (that is not itself an ancestor of a supported type) fails.

    Parameters::

    instance: Any
      The instance in the inheritance hierarchy tree.
    supported_types: type, list of type, tuple of type
      All types that are supported.
    name: str
      The checking target name.

    Returns::

    Any
      ``instance`` itself, when the check passes.

    Raises::

    TypeError
      If ``supported_types`` is not a type or a tuple/list of types.
    NotImplementedError
      If ``type(instance)`` is not a supported type or an ancestor of one.
    """
    mode_type = type(instance)
    if isinstance(supported_types, type):
        supported_types = (supported_types,)
    if not isinstance(supported_types, (tuple, list)):
        raise TypeError(f'supported_types must be a tuple/list of type. But wwe got {type(supported_types)}')
    for smode in supported_types:
        if not isinstance(smode, type):
            raise TypeError(f'supported_types must be a tuple/list of type. But wwe got {smode}')
    if any(issubclass(smode, mode_type) for smode in supported_types):
        return instance
    raise NotImplementedError(f"{name} does not support {instance}. We only support "
                              f"{', '.join([mode.__name__ for mode in supported_types])}. ")


def is_instance(
    instance: Any,
    supported_types: Union[Type, Sequence[Type]],
    name: str = ''
):
    r"""Check whether the ``instance`` is the instance of the given types.

    This function checks whether the given ``instance`` is an instance of
    the given ``supported_types`` (subclass instances included).

    Here we have the following inheritance hierarchy::

             A
           /   \
          B     C
         / \   / \
        D   E F   G

    If ``supported_types`` is ``[B, F]``, then

    - instances of ``A``, ``C`` or ``G`` fail the check.
    - instances of ``B``, ``D``, ``E`` or ``F`` pass the check.

    Parameters::

    instance: Any
      The instance in the inheritance hierarchy tree.
    supported_types: type, list of type, tuple of type
      All types that are supported.
    name: str
      The checking target name.

    Returns::

    ``instance`` itself, when the check passes.

    Raises::

    NotImplementedError
      If ``instance`` is not an instance of any of ``supported_types``.
    """
    if not name:
        name = 'We'
    # ``isinstance`` accepts a type or a tuple of types, but not a list.
    types_ = tuple(supported_types) if isinstance(supported_types, list) else supported_types
    if not isinstance(instance, types_):
        raise NotImplementedError(f"{name} expect to get an instance of {supported_types}. "
                                  f"But we got {type(instance)}. ")
    return instance


def is_elem_or_seq_or_dict(targets: Any,
                           elem_type: Union[type, Tuple[type, ...]],
                           out_as: str = 'tuple'):
    """Validate and normalize a single element, a sequence, or a dict of elements.

    Parameters::

    targets: Any
      ``None``, an ``elem_type`` instance, a list/tuple of them, or a dict whose
      values are all ``elem_type`` instances. The single-element check comes
      first, so a target that is itself an ``elem_type`` is treated as one element.
    elem_type: type or tuple of types
      The required element type.
    out_as: str or None
      ``'tuple'``/``'list'`` return the elements; ``'dict'`` returns a mapping
      (dict keys for dict input, ``id(elem)`` otherwise); ``None`` returns
      ``targets`` unchanged.

    Raises::

    ValueError
      If ``targets`` does not have one of the accepted forms.
    """
    assert out_as in ['tuple', 'list', 'dict', None], 'Only support to output as tuple/list/dict/None'
    error_msg = f'Only support {elem_type}, sequence of {elem_type}, or dict of {elem_type}.'

    # Collect (key, value) pairs for every accepted input form.
    if targets is None:
        pairs = []
    elif isinstance(targets, elem_type):
        pairs = [(id(targets), targets)]
    elif isinstance(targets, (list, tuple)):
        if not all(isinstance(item, elem_type) for item in targets):
            raise ValueError(error_msg)
        pairs = [(id(item), item) for item in targets]
    elif isinstance(targets, dict):
        if not all(isinstance(item, elem_type) for item in targets.values()):
            raise ValueError(error_msg)
        pairs = list(targets.items())
    else:
        raise ValueError(error_msg)

    if out_as is None:
        return targets
    values = [item for _, item in pairs]
    if out_as == 'list':
        return values
    if out_as == 'tuple':
        return tuple(values)
    if out_as == 'dict':
        return dict(pairs)
    raise KeyError


def is_all_vars(dyn_vars: Any, out_as: str = 'tuple'):
    """Validate that ``dyn_vars`` holds brainpy variables, via ``is_elem_or_seq_or_dict``.

    The accepted element types are ``VarList``, ``VarDict`` and ``Variable``
    from ``brainpy.math``. They are imported once and cached in ``var_obs``.
    """
    global var_obs
    if var_obs is None:
        from brainpy.math import VarDict, VarList, Variable
        var_obs = (VarList, VarDict, Variable)
    return is_elem_or_seq_or_dict(dyn_vars, var_obs, out_as=out_as)


def is_all_objs(targets: Any, out_as: str = 'tuple'):
    """Validate that ``targets`` holds ``BrainPyObject`` instances, via ``is_elem_or_seq_or_dict``.

    ``BrainPyObject`` is imported from ``brainpy.math.object_transform.base``
    on first use and cached in the module-level global.
    """
    global BrainPyObject
    if BrainPyObject is None:
        from brainpy.math.object_transform.base import BrainPyObject
    return is_elem_or_seq_or_dict(targets, BrainPyObject, out_as=out_as)


def _err_jit_true_branch(err_fun, x):
    if isinstance(x, (tuple, list)):
        x_shape_dtype = tuple(jax.ShapeDtypeStruct(arr.shape, arr.dtype) for arr in x)
    else:
        x_shape_dtype = jax.ShapeDtypeStruct(x.shape, x.dtype)
    jax.pure_callback(err_fun, x_shape_dtype, x, vmap_method='sequential')
    return


def _err_jit_false_branch(x):
    return


def _cond(err_fun, pred, err_arg):
    """Call ``err_fun(err_arg)`` on the host when ``pred`` is true.

    Parameters::

    err_fun: callable
      Function that raises the desired error; receives ``err_arg``.
    pred: bool, Array
      The predicate. It is passed through ``remove_vmap`` before being used
      as the ``jax.lax.cond`` predicate.
    err_arg: any
      The operand forwarded to ``err_fun``.
    """
    from brainpy.math.remove_vmap import remove_vmap

    @wraps(err_fun)
    def true_err_fun(arg, transforms=None):
        # ``jax.pure_callback`` calls the callback with only the operand
        # (``callback(arg)``). ``transforms`` is kept optional for compatibility
        # with the two-argument ``(arg, transforms)`` tap-function convention.
        err_fun(arg)

    cond(remove_vmap(pred),
         partial(_err_jit_true_branch, true_err_fun),
         _err_jit_false_branch,
         err_arg)


def jit_error(pred, err_fun, err_arg=None):
    """Check errors in a jit function.

    The predicate is converted with ``as_jax``. When it is true, ``err_fun`` is
    called with ``err_arg`` through a host callback.

    Parameters::

    pred: bool, Array
      The boolean prediction.
    err_fun: callable
      The error function, which raises errors.
    err_arg: any
      The argument passed into ``err_fun``.
    """
    from brainpy.math.interoperability import as_jax
    predicate = as_jax(pred)
    _cond(err_fun, predicate, err_arg)


jit_error_checking = jit_error


def jit_error_checking_no_args(pred: bool, err: Exception):
    """Raise ``err`` through a host callback when ``pred`` is true.

    Parameters::

    pred: bool, Array
      The boolean prediction. It is converted with ``as_jax`` and passed
      through ``remove_vmap`` before being used as the ``jax.lax.cond``
      predicate.
    err: Exception
      The error raised by the host callback when ``pred`` holds.
      NOTE(review): JAX may surface it wrapped in its own runtime error.

    Raises::

    AssertionError
      If ``err`` is not an ``Exception`` instance.
    """
    from brainpy.math.remove_vmap import remove_vmap
    from brainpy.math.interoperability import as_jax

    assert isinstance(err, Exception), 'Must be instance of Exception.'

    def true_err_fun(*_args, **_kwargs):
        # ``jax.pure_callback`` is given no operands here, so it calls this with
        # no arguments. A fixed ``(arg, transforms)`` signature would raise a
        # TypeError instead of ``err``.
        raise err

    cond(remove_vmap(as_jax(pred)),
         lambda: jax.pure_callback(true_err_fun, None),
         lambda: None)

